# -*- coding: utf-8 -*-
"""
Created on Mon Mar  2 21:12:23 2026

@author: Wieczorek_W_Station
"""

import os
import re
import datetime as dt
import pickle
import random as rd
import numpy as np
import pandas as pd
import torch
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
## TensorFlow oneDNN switch
# TensorFlow reads TF_ENABLE_ONEDNN_OPTS from the process environment when it
# is imported. The former plain assignment `TF_ENABLE_ONEDNN_OPTS=0` only
# created an unused Python variable and had no effect. Setting the variable
# here matters only if TensorFlow gets imported after this line.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

import seaborn as sbs
from matplotlib import pyplot as plt

# =============================================================================
# Reproducibility, tokenizer and model
# =============================================================================
random_seed = 42

# Seed Python's RNG and PyTorch's RNG (and every CUDA device, if any).
rd.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

# Run on the GPU when one is present, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# SciBERT (uncased, scientific vocabulary): tokenizer plus encoder; the
# encoder is moved onto the chosen device (nn.Module.to returns the module).
tokenizer = BertTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
model = BertModel.from_pretrained('allenai/scibert_scivocab_uncased').to(device)

## Set the folder structure for further calculation
root = "C:\\Users\\Wieczorek_W_Station\\Dropbox\\Arbeit Kassel\\paperideen\\Moltbook_Science\\Data\\"
path = os.path.join(root, "Molts")      # input CSVs
output = os.path.join(root, "Outputs")  # tabular results
figures = os.path.join(root, "Figures") # saved plots

# Create the result folders if they are missing. exist_ok=True replaces the
# former bare `except: pass`, which also silenced real errors (permissions,
# invalid path). The figures folder is created as well because the plotting
# cells below chdir into it.
for folder in (output, figures):
    os.makedirs(folder, exist_ok=True)
#%%
# =============================================================================
# load Data
# =============================================================================
# Both inputs are ';'-separated CSVs in the "Molts" folder. Each has an
# unnamed first column (pandas labels it "Unnamed: 0"), presumably a saved
# index, which is dropped right after reading.
os.chdir(path)

threadsDf, commentsDf = (
    pd.read_csv(csvName, sep = ";").drop(columns = "Unnamed: 0")
    for csvName in ("ThreadsAllSentiments.csv", "CommentsAllSentiments.csv")
)

#%%
# =============================================================================
# tokenize the thread contents and count the tokens of each thread
# =============================================================================

## Number of SciBERT word-piece tokens per thread, excluding special tokens.
# add_special_tokens=False leaves out [CLS]/[SEP], so no vocabulary-specific
# ids (formerly hard-coded as 102/103/0) are needed. No padding is requested,
# so no [PAD] tokens appear. The ids stay on the CPU: only their number is
# needed, and moving them to the GPU to compare element by element in Python
# was slow. No truncation is applied, so texts longer than the model's
# maximum length are still fully counted (the tokenizer may warn about it).
sumTokens = [
    len(tokenizer(t, add_special_tokens = False)["input_ids"])
    for t in tqdm(threadsDf.content)
]

threadsDf.insert(3, "NoTokens", sumTokens)

#%%
# =============================================================================
# Calculate descriptive statistics for threads / posts
# =============================================================================

## begin with a cumulative plot for the timestamps (daily basis)
# Calendar day on which each thread was created.
date = pd.to_datetime(threadsDf.created_at).dt.date

# Keep the intended position 21, but never insert past the last column:
# DataFrame.insert raises for a position beyond the frame's width.
threadsDf.insert(min(21, threadsDf.shape[1]), "date_created", date)

# Number of threads created per calendar day.
threads_created = (threadsDf.groupby(by = "date_created")
                   .size()
                   .reset_index(name = "value"))

# Cumulative number of created threads over time.
sbs.ecdfplot(data = threadsDf,
             x = "date_created",
             stat = "count")
#%%

## Per-thread size and engagement measures, renamed for display
threadsProperties = threadsDf.loc[:, ["NoTokens", "upvotes", "comment_count"]]
threadsProperties = threadsProperties.rename(columns = {"NoTokens": "Tokens",
                                                        "upvotes": "Upvotes",
                                                        "comment_count": "Comments"})

# Long format (columns "variable" / "value"), so each measure gets its own facet.
threadsPropertiesLong = threadsProperties.melt()

# One log-scaled histogram per measure, three facets per row.
os.chdir(figures)
g = sbs.displot(threadsPropertiesLong,
                x = "value",
                col = "variable",
                col_wrap = 3,
                hue = "variable",
                height = 3,
                log_scale = True)
g.set_titles("{col_name}")

plt.savefig("DescriptivesHistogram.png", dpi = 600)
plt.close()

#%%
## include sentiments into threadsProperties
# Ordinal code per sentiment label: 0 = most negative ... 4 = most positive.
# Key order matters, because the plotting code reuses it as the category order.
sentiment_map = {label: code for code, label in enumerate(
    ["Very Negative", "Negative", "Neutral", "Positive", "Very Positive"])}

# Encode every thread's label (an unexpected label raises KeyError) and put
# the codes at column position 3 of threadsProperties.
threadsProperties.insert(3,
                         "sentiments",
                         list(map(sentiment_map.__getitem__, threadsDf.sentiments)))

#%%
## Number of threads per sentiment label, from very negative to very positive
sentiments = threadsDf.groupby(by = "sentiments").size().reset_index(name = "values")
sbs.barplot(data = sentiments,
            x = "sentiments",
            y = "values",
            order = list(sentiment_map))
plt.savefig("SentimentsThreads.png",
            dpi = 600)
plt.close()

## Density of the mean comment sentiment per post (0-4 ordinal scale)
# Thread sentiments are text labels (hence the sentiment_map lookup for
# threads). The comment file presumably uses the same labels, and
# groupby().mean() cannot average text (pandas 2.x raises TypeError). Text
# labels are therefore encoded first; an already-numeric column is used as-is.
commentSentiments = commentsDf["sentiments"]
if not pd.api.types.is_numeric_dtype(commentSentiments):
    commentSentiments = commentSentiments.map(sentiment_map)
    # Fail loudly on labels outside sentiment_map instead of silently
    # dropping them as NaN from the per-post means.
    unknownLabels = commentsDf["sentiments"][
        commentSentiments.isna() & commentsDf["sentiments"].notna()].unique()
    if len(unknownLabels):
        raise ValueError("Unknown comment sentiment labels: "
                         f"{list(unknownLabels)}")

sentimentsPosts = (commentSentiments.groupby(commentsDf["post_id"])
                   .mean()
                   .reset_index(name = "sentiments"))
sbs.kdeplot(data = sentimentsPosts,
            x = "sentiments",
            fill = True)  # `fill` is a bool; the former string was only truthy
# Label the numeric codes with their sentiment names.
plt.xticks(ticks = list(sentiment_map.values()),
           labels = list(sentiment_map))
plt.savefig("SentimentsComments.png",
            dpi = 600)
plt.close()

#%%

## create heatmap of correlation (pearson) between the thread properties
threadsPropertiesCorr = threadsProperties.corr().round(2)
g = sbs.heatmap(threadsPropertiesCorr,
                annot = True,
                # Full [-1, 1] range centred on 0. With vmin = 0 every negative
                # correlation got the same colour as zero.
                vmin = -1,
                vmax = 1,
                center = 0,
                # Hide the trivial r = 1 diagonal; sized from the data, not fixed at 4.
                mask = np.eye(len(threadsPropertiesCorr), dtype = bool))

# Keep the row labels horizontal.
g.set_yticklabels(labels = g.get_yticklabels(), rotation = 0)
plt.tight_layout()
plt.savefig("CorrelationsHeatmap.png",
            dpi = 600)
# Close like the other plots so later cells do not draw onto this figure.
plt.close()


#%%
## calculate the type and number of submolts of threads
# The `submolt` column presumably holds a dict rendered as text; the lookup
# expects a `'name': '<value>'` entry in it. Only the quoted value after the
# key is captured. The former greedy pattern ("'name':.+,") ran to the LAST
# comma of the string, so later fields leaked into the result, and its
# clean-up deleted the substring "name" from the value itself.
submoltNamePattern = re.compile(r"""['"]name['"]\s*:\s*(['"])(.*?)\1""")


def _submolt_name(raw):
    """Return the submolt name stored in one `submolt` cell."""
    match = submoltNamePattern.search(raw)
    if match is None:
        raise ValueError(f"no 'name' entry in submolt value: {raw!r}")
    return match.group(2).strip()


submolts = [_submolt_name(m) for m in threadsDf.submolt]

threadsDf.insert(4, "submolts", submolts)

# Threads per submolt, most frequent first (columns: "submolts", "count").
# rename_axis fixes the index name across pandas versions before it becomes a column.
submoltsCount = (threadsDf["submolts"].value_counts()
                 .rename_axis("submolts")
                 .reset_index(name = "count"))


## toDo: explore Submolts
os.chdir(output)
submoltsCount.to_csv("submolts.csv", sep = ";")